{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "nWg8AX79c3SZ"
   },
   "source": [
    "# Negative (Proxy) Controls for Unobserved Confounding\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "8TMUxMlF-QpJ"
   },
   "source": [
    "Consider the following SEM, where $Y$ is the outcome, $D$ is the treatment, $A$ is some unobserved confounding, and $Q$, $X$, $S$ are the observed covariates. In particular, $Q$ is considered to be the proxy control treatment as it a priori has no effect on the actual outcome $Y$, and $S$ is considered to be the proxy control outcome as it a priori is not affected by the actual treatment $D$. See also [An Introduction to Proximal Causal Learning](https://arxiv.org/pdf/2009.10982.pdf), for more information on this setting.\n",
    "\n",
    "![proxy_dag.png](https://raw.githubusercontent.com/stanford-msande228/winter23/main/proxy_dag.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "lAEWlzrqkH-T"
   },
   "source": [
    "Under linearity assumptions, the average treatment effect can be estimated by solving the vector of moment equations:\n",
    "\\begin{align}\n",
    "E\\left[(\\tilde{Y} - \\alpha \\tilde{D} - \\delta \\tilde{S}) \\left(\\begin{aligned}\\tilde{D}\\\\ \\tilde{Q}\\end{aligned}\\right) \\right] = 0\n",
    "\\end{align}\n",
    "where for every variable $V$ we denote with $\\tilde{V} = V - E[V|X]$.\n",
    "\n",
    "When the dimension of the proxy treatment variables $Q$ is larger than the dimension of proxy outcome variables $S$, then the above system of equations is over-identified. In these settings, we first project the \"technical instrument\" variables $\\tilde{V}=(\\tilde{D}, \\tilde{Q})$ onto the space of \"technical treatment\" variables $\\tilde{W}=(\\tilde{D}, \\tilde{S})$ and use the projected $\\tilde{V}$ as a new \"technical instrument\". In particular, we run an OLS regression of $\\tilde{W}$ on $\\tilde{V}$, and define $\\tilde{Z} = E[\\tilde{W}\\mid \\tilde{V}] = B \\tilde{V}$, where the $t$-th row $\\beta_t$ of the matrix $B$ is the OLS coefficient in the regression of $\\tilde{W}_t$ on $\\tilde{V}$. These new variables $\\tilde{Z}$, can also be viewed as engineered technical instrument variables. Then we have the exactly identified system of equations:\n",
    "\\begin{align}\n",
    "E\\left[(\\tilde{Y} - \\alpha \\tilde{D} - \\delta \\tilde{S}) \\tilde{Z} \\right] := E\\left[(\\tilde{Y} - \\alpha \\tilde{D} - \\delta \\tilde{S}) B \\left(\\begin{aligned}\\tilde{D}\\\\ \\tilde{Q}\\end{aligned}\\right) \\right] = 0\n",
    "\\end{align}\n",
    "\n",
    "In fact the solution to this system of equations is numerically equivalent to the following two stage algorithm:\n",
    "- Run OLS of $\\tilde{W}=(\\tilde{D}, \\tilde{S})$ on $\\tilde{V}=(\\tilde{D}, \\tilde{Q})$\n",
    "- Define $\\tilde{Z}$ as the predictions of the OLS model\n",
    "- Run OLS of $\\tilde{Y}$ on $\\tilde{Z}$.\n",
    "This is the well-known Two-Stage-Least-Squares (2SLS) algorithm for instrumental variable regression."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "gIw0pCdbGCVe"
   },
   "outputs": [],
   "source": [
    "# Import relevant packages\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sklearn.model_selection import cross_val_predict, KFold\n",
    "from sklearn.pipeline import make_pipeline\n",
    "from sklearn.linear_model import LassoCV, LinearRegression\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "import patsy\n",
    "import warnings\n",
    "from sklearn.multioutput import MultiOutputRegressor\n",
    "warnings.simplefilter('ignore')\n",
    "np.random.seed(1234)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "cCfAVHdkIJ4w"
   },
   "source": [
    "# Analyzing Simulated Data\n",
    "\n",
    "First, let's evaluate the methods on simulated data generated from a linear SEM characterized by the above DAG. For this simulation, we'll set the ATE to 2."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ATTykpRFAQPC"
   },
   "outputs": [],
   "source": [
    "# generate data from the SCM\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def gen_data(n, ate):\n",
    "    X = np.random.normal(0, 1, size=(n, 10))\n",
    "    A = 2 * X[:, [0]] + np.random.normal(0, 1, size=(n, 1))\n",
    "    Q = 10 * A + 2 * X[:, [0]] + np.random.normal(0, 1, size=(n, 1))\n",
    "    S = 5 * A + X[:, [0]] + np.random.normal(0, 1, size=(n, 1))\n",
    "    D = Q - A + 2 * X[:, [0]] + np.random.normal(0, 1, size=(n, 1))\n",
    "    Y = ate * D + 5 * A + 2 * S + 0.5 * X[:, [0]] + np.random.normal(0, 1, size=(n, 1))\n",
    "    return [X, A, Q, S, D, Y.flatten()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "YQa7gj3NB-uY"
   },
   "outputs": [],
   "source": [
    "X, A, Q, S, D, y = gen_data(5000, 2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "HTX374msGCVg"
   },
   "source": [
    "We define the techincal instrument $V=(D, Q)$ and technical treatment $W=(D, S)$ and then it is a matter of solving an instrument variable regression problem with instruments $V$ and treatments $W$ and looking at the first coefficient associated with $D$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "7Kc2jGXBGCVh"
   },
   "outputs": [],
   "source": [
    "V = np.hstack([D, Q])  # technical instruments 5000, 2\n",
    "W = np.hstack([D, S])  # technical treatments 5000, 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "cDBfCHOdGCVh"
   },
   "source": [
    "### Partialling-Out X"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Q5gbLM1cGCVi"
   },
   "outputs": [],
   "source": [
    "modely = make_pipeline(StandardScaler(), LassoCV())\n",
    "modelw = make_pipeline(StandardScaler(), LassoCV())\n",
    "modelv = make_pipeline(StandardScaler(), LassoCV())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "koTICN8tGCVi"
   },
   "outputs": [],
   "source": [
    "resy = y - modely.fit(X, y).predict(X)\n",
    "resV = V - MultiOutputRegressor(modelv).fit(X, V).predict(X)  # residual instrument\n",
    "resW = W - MultiOutputRegressor(modelw).fit(X, W).predict(X)  # residual treatment"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "t_6Pp6wlGCVi"
   },
   "source": [
    "### Approach 1: Solving the Moment Equation\n",
    "\n",
    "In this case since $V$ and $W$ have the same dimension, we can just solve the moment equation:\n",
    "\\begin{align}\n",
    "E\\left[(\\tilde{Y} - \\theta'\\tilde{W}) \\tilde{V} \\right] = 0\n",
    "\\end{align}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "zYqYnbNnGCVi"
   },
   "outputs": [],
   "source": [
    "n = resV.shape[0]\n",
    "J = (resV.T @ resW) / n\n",
    "alpha = (resV.T @ resy) / n\n",
    "point = np.linalg.inv(J) @ alpha\n",
    "point[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "klRLTQuwGCVi"
   },
   "source": [
    "### Approach 2: Projecting the Instrument on the Treatment\n",
    "\n",
    "Alternatively, we could have constructed a technical instrument by calculating regression $W$ on $V$ with OLS and using $Z=E[W|V]$ as the new instrument."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "0DX6DdhgGCVi"
   },
   "outputs": [],
   "source": [
    "resZ = LinearRegression(fit_intercept=False).fit(resV, resW).predict(resV)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "vb_li392GCVj"
   },
   "outputs": [],
   "source": [
    "J = (resZ.T @ resW) / n\n",
    "alpha = (resZ.T @ resy) / n\n",
    "point = np.linalg.inv(J) @ alpha\n",
    "point[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "B1mxf76rGCVj"
   },
   "source": [
    "In this case we see that because we started with an \"exactly\" identified system, this projection step doesn't change the result. This is provably always the case."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "JaD28NkGGCVj"
   },
   "source": [
    "### Approach 3:  2SLS\n",
    "\n",
    "We can take one step further and use the two stage least squares approach, were we run OLS of $\\tilde{Z}$ on $\\tilde{y}$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "3ZITM23jGCVj"
   },
   "outputs": [],
   "source": [
    "LinearRegression(fit_intercept=False).fit(resZ, resy).coef_[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "-u7fID2KGCVj"
   },
   "source": [
    "We see that again this doesn't change the result, since 2SLS is equivalent to solving the moment condition in Approach 2."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "jgG5SDhUGCVj"
   },
   "source": [
    "# With Cross-Fitting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "L2HIw7NHGCVj"
   },
   "outputs": [],
   "source": [
    "def proxydml(X, Q, S, D, y, modely, modelw, modelv, *, nfolds):\n",
    "    '''\n",
    "    DML for the Partially Linear Model setting with cross-fitting\n",
    "\n",
    "    Input\n",
    "    -----\n",
    "    X: the controls\n",
    "    Q: the treatment proxy\n",
    "    S: the outcome proxy\n",
    "    D: the treatment\n",
    "    y: the outcome\n",
    "    modely: the ML model for predicting the outcome y\n",
    "    modelw: the ML model for predicting the technical treatments W=(D, S) from X\n",
    "    modelv: the ML model for predicting the technical instruments V=(D, Q) from X\n",
    "    nfolds: the number of folds in cross-fitting\n",
    "\n",
    "    Output\n",
    "    ------\n",
    "    point: the point estimate of the treatment effect of D on y\n",
    "    yhat: the cross-fitted predictions for the outcome y\n",
    "    What: the cross-fitted predictions for the technical treatments W\n",
    "    Vhat: the cross-fitted predictions for the technical instruments V\n",
    "    resy: the outcome residuals\n",
    "    resW: the treatment residuals\n",
    "    resV: the instrument residuals\n",
    "    '''\n",
    "    W = np.hstack([D, S])  # technical treatments\n",
    "    V = np.hstack([D, Q])  # technical instruments\n",
    "\n",
    "    cv = KFold(n_splits=nfolds, shuffle=True, random_state=123)  # shuffled k-folds\n",
    "    yhat = cross_val_predict(modely, X, y, cv=cv, n_jobs=-1)  # out-of-fold predictions for y\n",
    "    What = cross_val_predict(MultiOutputRegressor(modelw), X, W, cv=cv, n_jobs=-1)\n",
    "    Vhat = cross_val_predict(MultiOutputRegressor(modelv), X, V, cv=cv, n_jobs=-1)\n",
    "\n",
    "    # calculate outcome and treatment residuals\n",
    "    resy = y - yhat\n",
    "    resW = W - What\n",
    "    resV = V - Vhat\n",
    "\n",
    "    # project the residual instruments on the residual treatments\n",
    "    resZ = LinearRegression(fit_intercept=False).fit(resV, resW).predict(resV)\n",
    "\n",
    "    # final stage ols based point estimate and standard error\n",
    "    n = resW.shape[0]\n",
    "    J = (resZ.T @ resW) / n\n",
    "    Jinv = np.linalg.inv(J)\n",
    "    alpha = (resZ.T @ resy) / n\n",
    "    params = Jinv @ alpha\n",
    "    point = params[0]\n",
    "    stderr = None  # implement this as an exercise; see corresponding calculations in IV notebooks\n",
    "\n",
    "    return point, stderr, yhat, What, Vhat, resy, resW, resV"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "w9rT2wx4GCVk"
   },
   "outputs": [],
   "source": [
    "def summary(point, stderr, yhat, What, Vhat, resy, resW, resV, X, Q, S, D, y, *, name):\n",
    "    '''\n",
    "    Convenience summary function that takes the results of the DML function\n",
    "    and summarizes several estimation quantities and performance metrics.\n",
    "    '''\n",
    "    lower = point - 1.96 * stderr if stderr is not None else None\n",
    "    upper = point + 1.96 * stderr if stderr is not None else None\n",
    "    W = np.hstack([D, S])\n",
    "    V = np.hstack([D, Q])\n",
    "    return pd.DataFrame({'estimate': point,  # point estimate\n",
    "                         'stderr': stderr,  # standard error\n",
    "                         'lower': lower,  # lower end of 95% confidence interval\n",
    "                         'upper': upper,  # upper end of 95% confidence interval\n",
    "                         'rmse y': np.sqrt(np.mean(resy**2)),  # RMSE of model that predicts outcome y\n",
    "                         'r2 y': 1 - np.mean(resy**2) / np.var(y),\n",
    "                         'rmse W': np.sqrt(np.mean(resW**2)),  # RMSE of model that predicts treatments W\n",
    "                         'avg. r2 W': np.mean(1 - np.mean(resW**2, axis=0) / np.var(W, axis=0)),\n",
    "                         'rmse V': np.sqrt(np.mean(resV**2)),  # RMSE of model that predicts treatments V\n",
    "                         'avg. r2 V': np.mean(1 - np.mean(resV**2, axis=0) / np.var(V, axis=0)),\n",
    "                         }, index=[name])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "90sA4huMGCVk"
   },
   "outputs": [],
   "source": [
    "X, A, Q, S, D, y = gen_data(5000, 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "9r0XMo_WGCVk"
   },
   "outputs": [],
   "source": [
    "cv = KFold(n_splits=5, shuffle=True, random_state=123)\n",
    "modely = make_pipeline(StandardScaler(), LassoCV(cv=cv))\n",
    "modelw = make_pipeline(StandardScaler(), LassoCV(cv=cv))\n",
    "modelv = make_pipeline(StandardScaler(), LassoCV(cv=cv))\n",
    "res = proxydml(X, Q, S, D, y, modely, modelw, modelv, nfolds=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "kwvZStUKGCVk"
   },
   "outputs": [],
   "source": [
    "summary(*res, X, Q, S, D, y, name='lassocv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "YJ2QbI4-GCVk"
   },
   "outputs": [],
   "source": [
    "from joblib import Parallel, delayed\n",
    "\n",
    "\n",
    "def exp(it):\n",
    "    np.random.seed(it)\n",
    "    X, A, Q, S, D, y = gen_data(5000, 2)\n",
    "    cv = KFold(n_splits=5, shuffle=True, random_state=it)\n",
    "    lassoy = make_pipeline(StandardScaler(), LassoCV(cv=cv))\n",
    "    lassow = make_pipeline(StandardScaler(), LassoCV(cv=cv))\n",
    "    lassov = make_pipeline(StandardScaler(), LassoCV(cv=cv))\n",
    "    res = proxydml(X, Q, S, D, y, lassoy, lassow, lassov, nfolds=3)\n",
    "    point = res[0]\n",
    "    stderr = 0 if res[1] is None else res[1]  # this will be fixed once stderr is implemented!\n",
    "    return point, point - 1.96 * stderr, point + 1.96 * stderr\n",
    "\n",
    "\n",
    "results = Parallel(n_jobs=-1, verbose=3)(delayed(exp)(it) for it in range(100))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Xq8E10lGGCVk"
   },
   "outputs": [],
   "source": [
    "points, lowers, uppers = zip(*results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "wTBkvJkvGCVk"
   },
   "outputs": [],
   "source": [
    "coverage = np.mean((np.array(lowers) <= 2) & (2 <= np.array(uppers)))\n",
    "coverage"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "PNcRBZORGCVk"
   },
   "outputs": [],
   "source": [
    "np.std(points)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "-XRJ06vKGCVl"
   },
   "outputs": [],
   "source": [
    "np.mean(points)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "lGrVzQVlI2IH"
   },
   "source": [
    "## Real Data - Effects of Smoking on Birth Weight"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "BS0NRRjy6lPq"
   },
   "source": [
    "In this study, we will be studying the effects of smoking on baby weight. Base on the domain knowledge, we will consider the following setup:\n",
    "\n",
    "Outcome ($Y$): baby weight\n",
    "\n",
    "Treatment ($D$): smoking\n",
    "\n",
    "Unobserved condounding ($A$): family income\n",
    "\n",
    "The observed covariates are put in to 3 groups:\n",
    "\n",
    "\n",
    "*   Proxy treatment control ($Q$): mother's education\n",
    "*   Proxy outcome control ($S$): parity (total number of previous pregnancies)\n",
    "*   Other observed covariates ($X$): mother's race and age and infant sex\n",
    "\n",
    "\n",
    "Education serves as a proxy treatment control $Q$ because it reflects unobserved confounding due to household income $A$ but has no direct medical effect on birth weight $Y$. Parity and sex serve as a proxy outcome control $S$ because family size reflects household income $A$ but is not directly caused by smoking $D$ or education $Q$.\n",
    "\n",
    "A description of the data used can be found [here](https://www.stat.berkeley.edu/users/statlabs/data/babies.readme)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "165kCbYoY35y"
   },
   "outputs": [],
   "source": [
    "data = pd.read_csv('https://www.stat.berkeley.edu/users/statlabs/data/babies23.data', sep='\\\\s+')\n",
    "data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "pTe7hFMemlFP"
   },
   "outputs": [],
   "source": [
    "# Filter data so to exclude entries where income, number of cigarettes smoked,\n",
    "# parity, and baby weight are not asked or not known\n",
    "data = data[data.wt != 999]\n",
    "data = data[data.parity != 99]\n",
    "data = data[data.parity != 9]\n",
    "data = data[np.logical_and(data.number != 98, data.number != 99)]\n",
    "data = data[np.logical_and(data.inc != 98, data.inc != 99)]\n",
    "data.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "4E2f9SS2ZfXU"
   },
   "outputs": [],
   "source": [
    "X = np.array(patsy.dmatrix('0 + C(race) + age + C(sex)', data))\n",
    "D = np.array(patsy.dmatrix('0 + number', data))\n",
    "Q = np.array(patsy.dmatrix('0 + C(ed)', data))\n",
    "S = np.array(patsy.dmatrix('0 + parity', data))\n",
    "A = np.array(patsy.dmatrix('0 + inc', data))\n",
    "y = np.array(patsy.dmatrix('0 + wt', data)).flatten()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "pUAJG53t110L"
   },
   "outputs": [],
   "source": [
    "Q.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "-7oS-GI0GCVl"
   },
   "outputs": [],
   "source": [
    "cv = KFold(n_splits=5, shuffle=True, random_state=123)\n",
    "modely = make_pipeline(StandardScaler(), LassoCV(cv=cv))\n",
    "modelw = make_pipeline(StandardScaler(), LassoCV(cv=cv))\n",
    "modelv = make_pipeline(StandardScaler(), LassoCV(cv=cv))\n",
    "res = proxydml(X, Q, S, D, y, modely, modelw, modelv, nfolds=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Ph4FDYaQGCVl"
   },
   "outputs": [],
   "source": [
    "summary(*res, X, Q, S, D, y, name='lassocv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Y56iqvenGCVv"
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "colab": {
   "provenance": [
    {
     "file_id": "https://github.com/stanford-msande228/winter23/blob/main/Proxy_Controls.ipynb",
     "timestamp": 1702788346493
    }
   ]
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
